import numpy as np
import scipy.stats as stats
from scipy.optimize import minimize
import pandas as pd
from sklearn import linear_model
from random import choices
import statsmodels.api as sm
from statsmodels.genmod.families import links
from statsmodels.genmod.generalized_linear_model import GLM
from statsmodels.genmod.families import Binomial
from statsmodels.genmod import families
from scipy.stats import multivariate_normal, uniform, mode
from scipy.special import logsumexp
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import pickle
import os
import concurrent.futures
import sys

# Allow more than one copy of the Intel OpenMP runtime to be loaded in this
# process. NOTE(review): presumably added to avoid the "OMP: Error #15"
# abort seen with some MKL/conda setups; Intel documents this as an unsafe
# workaround that can produce incorrect results — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Sigmoid function
def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), elementwise for arrays."""
    denom = 1.0 + np.exp(-x)
    return 1.0 / denom

# Feature mapping
def phi(s, a, a_No=2):
    """Build action-blocked features for a batch of states.

    Each row of ``s`` is L2-normalised and written into the column block that
    belongs to its action; every other block stays zero.

    Args:
        s: 2-D array of states, one row per sample (n, d).
        a: one action index per row of ``s`` (anything ``np.array`` can
            reshape to (n, 1)).
        a_No: number of actions, i.e. number of column blocks.

    Returns:
        Array of shape (n, d * a_No).
    """
    n_rows = s.shape[0]
    d = s.shape[1]
    width = d * a_No
    acts = np.array(a).reshape((n_rows, 1))
    s_unit = s / np.linalg.norm(s, axis=1)[:, None]
    out = np.zeros((n_rows, width))
    # Flat offset of (row i, column a_i * d + j) in the row-major output.
    flat_idx = np.arange(n_rows)[:, None] * width + acts * d + np.arange(d)
    out.flat[flat_idx.ravel()] = s_unit.ravel()
    return out

# Reward Function
def get_reward(phi, th):
    """Draw one Bernoulli reward per feature row.

    Args:
        phi: feature matrix (one row per sample). Note: this parameter name
            shadows the module-level ``phi`` function inside this body.
        th: coefficient vector for the current step.

    Returns:
        0/1 integer draws with success probability ``sigmoid(phi . th)``.
    """
    success_prob = sigmoid(phi.dot(th))
    return np.random.binomial(1, success_prob)

def reject_accept(phi, a,dim):
    """Sample a next state in [0, 1]^dim by accept/reject.

    A uniform candidate ``x`` is accepted with probability
    ``min(1, w . exp(-x) / (sum(x) * (a + 1) + a))`` where
    ``w = phi * (a + 1) + a / dim``.

    Args:
        phi: current state vector (length ``dim``).
        a: (shifted) action value used to tilt the acceptance ratio.
        dim: state dimension.

    Returns:
        The accepted candidate, a length-``dim`` array.
    """
    weights = phi * (a + 1) + a / dim  # loop-invariant
    accepted = False
    while not accepted:
        candidate = np.random.uniform(0, 1, (dim))
        ratio = weights.dot(np.exp(-candidate)) / (np.sum(candidate) * (a + 1) + a)
        alpha = np.min((1, ratio))
        accepted = np.random.uniform(0, 1) <= alpha
    return candidate
            
def next_state(ph1,action):
    """Sample the next state for every trajectory.

    Args:
        ph1: current states, a 2-D array (n, dim).
        action: one action per row of ``ph1``; any array-like (NumPy array or
            plain list) is accepted.

    Returns:
        ``{'state': next_states}`` with ``next_states`` of shape (n, dim).
    """
    dim = ph1.shape[1]
    # np.asarray lets plain Python lists (e.g. from random.choices) through;
    # the +1 shift keeps the acceptance ratio in reject_accept away from a=0.
    shifted = np.asarray(action).reshape(-1) + 1
    result = np.array([reject_accept(row, a, dim) for row, a in zip(ph1, shifted)])
    return {'state': result}

def get_prob(H,a_No):
    """Return an (H, a_No) matrix of per-step action probabilities.

    Each row is drawn uniformly from [0.3, 0.7) and normalised to sum to 1.
    """
    raw = np.random.uniform(0.3, 0.7, (H, a_No))
    return raw / raw.sum(axis=1, keepdims=True)
    
def random_policy(prob,h,a_No,episodes_No):
    """Draw ``episodes_No`` actions for step ``h`` from the distribution ``prob[h]``."""
    step_dist = prob[h]
    return np.random.choice(range(a_No), size=episodes_No, p=step_dist)

def MDP(dim, a_No, H, episodes_No,prob=None, seed=None, policy=None, beta_hat_h=None,theta_hat=None, theta_var= None,Lamda=None, beta=None, beta0=None):
    """Simulate ``episodes_No`` trajectories of length ``H`` under ``policy``.

    The global NumPy RNG is reseeded with ``seed`` first, so the same seed
    yields the same true parameters ``theta_h`` and initial states.

    Args:
        dim: state dimension.
        a_No: number of actions.
        H: horizon (number of steps).
        episodes_No: number of trajectories simulated side by side.
        prob: (H, a_No) action probabilities; only read when policy == "random".
        seed: value passed to ``np.random.seed``.
        policy: None (uniform random actions), "random" (actions from
            ``prob``), "optimal" (greedy w.r.t. the true ``theta_h``),
            "linear" (pessimistic linear policy via ``fPEVI_hat_linear``),
            or any callable ``policy(h, states)`` returning one action per
            trajectory.
        beta_hat_h, Lamda, beta: required when policy == "linear".
        theta_hat, theta_var, beta0: accepted but not used in this function.

    Returns:
        ``(V, states, actions, rewards, theta_h)``: ``V`` is the sum over steps
        of the mean reward; ``states``/``actions``/``rewards`` are dicts keyed
        by step ``0..H-1``; ``theta_h`` has shape (dim * a_No, H).
    """
    
    np.random.seed(seed)
    # True reward coefficients, one column per step.
    theta_h = np.random.uniform(0, 1, (dim*a_No, H))

    Ps0 = np.random.uniform(0 ,1, (episodes_No, dim))
    states = {0: {'state':Ps0}}
    rewards, actions, phis_d = {}, {}, {}  # phis_d is never used below
    
    for h in range(H): 
        
        if policy is None:
            actions[h] = np.random.choice(range(a_No), size=episodes_No)

        elif policy=="random":
            actions[h]=random_policy(prob,h,a_No,episodes_No)

        elif policy=="optimal":
            actions[h]=optimal(states,h,a_No,theta_h,episodes_No)

        elif policy=="linear":
            # Also records the pessimistic value estimate on the step's state dict.
            actions[h],states[h]["V_hat"]= fPEVI_hat_linear(states,H,h,a_No,Lamda,beta_hat_h,beta)

        else:
            actions[h] = policy(h, states)
            
        rewards[h] = get_reward(phi(states[h]["state"],actions[h],a_No), theta_h[:,h])
        states[h + 1] = next_state(states[h]["state"],actions[h])

    # Drop the terminal states: only states where an action was taken are returned.
    del states[H]
    V = np.sum([np.mean(rewards[h]) for h in range(H)])
    return (V, states, actions, rewards, theta_h)


def MDP2(dim, a_No, H, episodes_No,policy=None, seed=None,beta_hat_h=None,theta_hat=None, Lamda=None, theta_var= None,beta=None, beta0=None):
    """Simulate trajectories and return them in the nested "hospitals" layout.

    The global NumPy RNG is reseeded with ``seed``. All randomness, including
    the uniform actions for ``policy=None``, comes from NumPy, so runs are
    reproducible.

    Args:
        dim, a_No, H, episodes_No: state dimension, number of actions,
            horizon, number of trajectories.
        policy: "glm" (pessimistic GLM policy via ``fPEVI_hat_glm``) or
            None (uniform random actions).
        seed: value passed to ``np.random.seed``.
        beta_hat_h, theta_hat, theta_var, Lamda, beta, beta0: GLM policy
            inputs; ``Lamda`` is indexed by step. Only used when policy == "glm".

    Returns:
        ``(V, hospitals)``. ``V`` is the sum over steps of the mean reward.
        ``hospitals`` holds "actions", "rewards", "theta_true" and "states";
        "states" is keyed 0..H.

    Raises:
        ValueError: if ``policy`` is neither "glm" nor None.
    """
    np.random.seed(seed)
    theta_h = np.random.uniform(0, 1, (dim*a_No, H))

    hospitals = {
        "actions": {},
        "rewards": {},
        "theta_true": theta_h,
        "states": {h: {} for h in range(H + 1)},
    }
    hospitals["states"][0]["state"] = np.random.uniform(0, 1, (episodes_No, dim))

    for h in tqdm(range(H)):
        states = hospitals["states"][h]["state"]

        if policy == "glm":
            step_actions, _ = fPEVI_hat_glm(theta_hat, theta_var, states, H, h, a_No,
                                            Lamda[h], beta_hat_h, beta, beta0)
        elif policy is None:
            # Use NumPy (seeded above) rather than random.choices so runs are
            # reproducible and actions are an ndarray.
            step_actions = np.random.choice(a_No, size=episodes_No)
        else:
            raise ValueError(f"unsupported policy for MDP2: {policy!r}")

        hospitals["actions"][h] = step_actions
        hospitals["rewards"][h] = get_reward(phi(states, step_actions, a_No), theta_h[:, h])
        hospitals["states"][h + 1] = next_state(states, step_actions)

    V = np.sum([np.mean(hospitals["rewards"][h]) for h in range(H)])
    return V, hospitals

def optimal(states,h,a_No,theta_h,episodes_No):
    """Greedy action per trajectory w.r.t. the true linear score phi(s, a) . theta_h[:, h].

    Returns:
        Integer array of length ``episodes_No`` with the best action per state.
    """
    theta = theta_h[:, h]
    s_h = states[h]['state']
    scores = np.vstack([
        phi(s=s_h, a=np.tile(act, episodes_No), a_No=a_No) @ theta
        for act in range(a_No)
    ])
    return np.argmax(scores, axis=0)

def Gamma0(Z,theta_value,beta0):
    """Return the GLM bonus ``beta0 * sqrt(Z^T theta_value Z)`` for one feature vector."""
    quad_form = Z @ theta_value @ Z
    return beta0 * np.sqrt(quad_form)

def Gamma(Z,L,beta):
    """Return the pessimism bonus ``beta * sqrt(Z^T L^{-1} Z)``.

    Uses a linear solve instead of forming ``inv(L)``; this is cheaper and
    numerically more stable for the (symmetric positive-definite) Gram
    matrices used in this file.
    """
    return beta * np.sqrt(Z.dot(np.linalg.solve(L, Z)))

def Qbar(Z,betahat,thetahat,GammVal,GammVal0):
    """Pessimistic GLM Q estimate: linear part + logistic part - both bonuses."""
    linear_part = Z.dot(betahat)
    logistic_part = sigmoid(Z.dot(thetahat))
    return linear_part + logistic_part - GammVal - GammVal0

def Qbar_linear(Z,betahat,GammVal):
    """Pessimistic linear Q estimate: ``Z . betahat`` minus the bonus ``GammVal``."""
    estimate = Z @ betahat
    return estimate - GammVal

def Qhat(QbarVal,H,h):
    """Clip a Q estimate into ``[0, H - h]`` (the achievable reward-to-go range).

    NaN inputs pass through unchanged, as with the builtin ``min``/``max``.
    """
    upper = H - h
    capped = upper if upper < QbarVal else QbarVal
    return 0 if 0 > capped else capped

def cal_grad(x, theta):
    """Derivative of the sigmoid at ``x . theta``: ``s * (1 - s)`` with ``s = sigmoid(x . theta)``."""
    s = sigmoid(x.dot(theta))
    return s * (1 - s)
        

# Linear Regression Model

def fPEVI_hat_linear(states,H,h,a_No,Lam,beta_h_hat,beta):
    """Pessimistic greedy policy and value at step ``h`` for the linear model.

    For each state ``s`` in ``states[h]['state']`` and action ``a``:
    ``Q(s, a) = clip(z . beta_h_hat[h] - beta * sqrt(z^T Lam[h]^{-1} z), 0, H - h)``
    with ``z = phi(s, a)``. Row by row this equals
    ``Qhat(Qbar_linear(z, beta_h_hat[h], Gamma(z, Lam[h], beta)), H, h)``.

    Args:
        states: dict keyed by step; ``states[h]['state']`` is (n, dim).
        H, h: horizon and current step.
        a_No: number of actions.
        Lam: dict of Gram matrices keyed by step.
        beta_h_hat: dict of regression coefficients keyed by step.
        beta: bonus scale.

    Returns:
        ``(pi_hat, V_hat)``: best action and its clipped value per state.
    """
    s = states[h]['state']
    n = s.shape[0]
    beta_hat_h = beta_h_hat[h]
    # Invert once per call instead of once per (row, action).
    Lam_inv = np.linalg.inv(Lam[h])
    upper = H - h
    Qhat_sa = np.empty((a_No, n))
    for a in range(a_No):
        Z_a = phi(s=s, a=np.tile(a, n), a_No=a_No)
        # Row-wise z_i^T Lam^{-1} z_i; clamp tiny negative round-off before sqrt.
        quad = np.maximum(np.sum((Z_a @ Lam_inv) * Z_a, axis=1), 0.0)
        qbar = Z_a @ beta_hat_h - beta * np.sqrt(quad)
        Qhat_sa[a] = np.maximum(np.minimum(qbar, upper), 0)
    pi_hat = np.argmax(Qhat_sa, 0)
    V_hat = np.max(Qhat_sa, 0)
    return pi_hat, V_hat

def L_linear(hospitals,h,d0,lamda=1, a_No=2):
    """Return the ridge Gram matrix ``X^T X + lamda * I_d0`` for step ``h``.

    ``X`` stacks ``phi(state, taken action)`` for every logged sample at step ``h``.
    """
    design = phi(s=hospitals['states'][h]['state'], a=hospitals['actions'][h], a_No=a_No)
    gram = design.T @ design
    return gram + lamda * np.eye(d0)

def Linear(hospitals,a_No,d0,c=0.01):
    """Fit pessimistic linear value iteration backwards over the horizon.

    At each step ``h`` (from ``H-1`` down to 0), rewards plus the next-step
    pessimistic value are ridge-regressed on ``phi(state, action)``. The
    resulting policy/value from ``fPEVI_hat_linear`` is written into
    ``hospitals['states'][h]`` under "pi" and "V_hat", mutating ``hospitals``.

    Args:
        hospitals: logged data with "states", "actions", "rewards" keyed by step.
        a_No: number of actions.
        d0: feature dimension (``dim * a_No``).
        c: scale of the pessimism bonus ``beta``.

    Returns:
        ``(beta_h_hat, beta, Lam)``: per-step coefficients, bonus scale, and
        per-step Gram matrices.
    """
    n = len(hospitals['actions'][0])
    H = len(hospitals['actions'])
    lamda, xi, d = 1, 0.95, d0
    zeta = np.log(2*d*H*n/xi)
    beta = c*d*H*np.sqrt(zeta)

    fed_Vhat = {H: np.zeros(n)}
    beta_h_hat = {}
    Lam = {}

    for h in range(H-1, -1, -1):
        step = hospitals['states'][h]
        X = phi(s=step['state'], a=hospitals['actions'][h], a_No=a_No)
        y = hospitals['rewards'][h] + fed_Vhat[h+1]
        Lam[h] = L_linear(hospitals, h, d0, lamda=lamda, a_No=a_No)
        # Ridge solution Lam^{-1} X^T y via a solve rather than an explicit inverse.
        beta_h_hat[h] = np.linalg.solve(Lam[h], X.T @ y)
        step["pi"], step["V_hat"] = fPEVI_hat_linear(hospitals['states'], H, h, a_No, Lam, beta_h_hat, beta)
        fed_Vhat[h] = step["V_hat"]

    return beta_h_hat, beta, Lam

def est_theta_glm(hospitals, a_No):
    """Fit one logistic-link Binomial GLM of rewards on ``phi(state, action)`` per step.

    NOTE(review): no regularisation is used; statsmodels may warn or fail on
    perfectly separated data — not handled here.

    Returns:
        Dict mapping step ``h`` to the fitted coefficient vector.
    """
    estimates = {}
    for step in range(len(hospitals['actions'])):
        design = phi(s=hospitals['states'][step]['state'], a=hospitals['actions'][step], a_No=a_No)
        logit_family = sm.families.Binomial(link=sm.families.links.Logit())
        fitted = sm.GLM(hospitals['rewards'][step], design, family=logit_family).fit()
        estimates[step] = fitted.params
    return estimates

def est_theta_var_glm(hospitals,theta_hat, a_No):
    """Per-step GLM information matrices, via ``theta_variance_glm``, keyed by step."""
    n_steps = len(hospitals['actions'])
    return {
        step: theta_variance_glm(
            phi(s=hospitals['states'][step]['state'], a=hospitals['actions'][step], a_No=a_No),
            theta_hat[step],
        )
        for step in range(n_steps)
    }

def theta_variance_glm(X, theta_hat):
    """Return ``X^T W X`` with ``W = diag(p * (1 - p))`` and ``p = sigmoid(X . theta_hat)``.

    This is the logistic-regression Fisher information at ``theta_hat``.
    """
    p = sigmoid(X.dot(theta_hat))
    weights = p * (1 - p)
    return (X.T * weights).dot(X)

def L_glm(hospitals,h,d0,lamda=1, a_No=2):
    """Ridge Gram matrix ``X^T X + lamda * I_d0`` of the logged features at step ``h``."""
    feats = phi(s=hospitals['states'][h]['state'], a=hospitals['actions'][h], a_No=a_No)
    ridge = lamda * np.eye(d0)
    return feats.T @ feats + ridge

def fPEVI_hat_glm(theta_hat,theta_var,states,H,h,a_No,Lamda,beta_h_hat,beta,beta0):
    """Pessimistic greedy policy and value at step ``h`` for the GLM model.

    For each state row ``s`` and action ``a``, with ``z = phi(s, a)`` and
    ``p = sigmoid(z . theta_hat[h])``:
    ``Q = z . beta_h_hat[h] + p - beta * sqrt(z^T Lamda^{-1} z)
          - p(1-p) * beta0 * sqrt(z^T pinv(theta_var[h]) z)``,
    then clipped to ``[0, H - h]``.

    Args:
        theta_hat, theta_var, beta_h_hat: dicts keyed by step.
        states: 2-D array of states at step ``h`` (n, dim).
        H, h: horizon and current step.
        a_No: number of actions.
        Lamda: Gram matrix for step ``h``.
        beta, beta0: bonus scales.

    Returns:
        ``(pi_hat, V_hat)``: best action and its clipped value per state.
    """
    n = states.shape[0]
    theta_hat_h = theta_hat[h]
    beta_hat_h = beta_h_hat[h]
    theta_value = np.linalg.pinv(theta_var[h])
    # Invert once per call instead of once per (row, action).
    Lamda_inv = np.linalg.inv(Lamda)
    upper = H - h
    Qhat_sa = np.empty((a_No, n))
    for a in range(a_No):
        Z_a = phi(s=states, a=np.tile(a, n), a_No=a_No)
        p = sigmoid(Z_a @ theta_hat_h)
        grad = p * (1 - p)
        # Row-wise quadratic forms; clamp round-off negatives (pinv of a
        # near-singular matrix) so sqrt never produces NaN.
        quad0 = np.maximum(np.sum((Z_a @ theta_value) * Z_a, axis=1), 0.0)
        quad = np.maximum(np.sum((Z_a @ Lamda_inv) * Z_a, axis=1), 0.0)
        gamma0 = grad * beta0 * np.sqrt(quad0)
        gamma = beta * np.sqrt(quad)
        qbar = Z_a @ beta_hat_h + p - gamma - gamma0
        Qhat_sa[a] = np.maximum(np.minimum(qbar, upper), 0)
    pi_hat = np.argmax(Qhat_sa, 0)
    V_hat = np.max(Qhat_sa, 0)
    return pi_hat, V_hat

def GLM(hospitals,a_No,d0,episodes_No,c=0.005):
    """Fit pessimistic value iteration with a GLM reward model.

    NOTE: this definition shadows the ``GLM`` class imported from statsmodels
    at the top of the file; model fitting uses ``sm.GLM`` instead.

    The per-step reward is modelled by a logistic GLM (``est_theta_glm``).
    The next-step value ``V_k[h+1]`` is ridge-regressed on
    ``phi(state, action)`` using the same Gram matrix ``Lamda`` that defines
    the pessimism bonus.

    Args:
        hospitals: logged data with "states", "actions", "rewards" keyed by step.
        a_No: number of actions.
        d0: feature dimension (``dim * a_No``).
        episodes_No: accepted for interface compatibility; not used.
        c: scale of the bonuses ``beta`` and ``beta0``.

    Returns:
        ``(beta_h_hat, beta, beta0, Lam, theta_hat, theta_var)``.
    """
    n = len(hospitals['actions'][0])
    H = len(hospitals['actions'])
    lamda, xi, d = 1, 0.95, d0
    zeta = np.log(2*d*H*n/xi)

    beta = c*d*H*np.sqrt(zeta)
    beta0 = c*np.sqrt(d*np.log(H/xi))

    theta_hat = est_theta_glm(hospitals, a_No)
    theta_var = est_theta_var_glm(hospitals, theta_hat, a_No)

    V_k = {H: np.zeros(n)}
    beta_h_hat = {}
    Lam = {}
    for h in range(H-1, -1, -1):
        states = hospitals["states"][h]["state"]
        X = phi(s=states, a=hospitals['actions'][h], a_No=a_No)
        y = V_k[h+1]
        Lamda = L_glm(hospitals, h, d0, lamda=lamda, a_No=a_No)
        Lam[h] = Lamda
        # Ridge estimate Lamda^{-1} X^T y. The plain X^T X is singular (and
        # np.linalg.inv raises) whenever some action is never taken at step h.
        beta_h_hat[h] = np.linalg.solve(Lamda, X.T @ y)
        _, V_k[h] = fPEVI_hat_glm(theta_hat, theta_var, states, H, h, a_No, Lamda, beta_h_hat, beta, beta0)
    return beta_h_hat, beta, beta0, Lam, theta_hat, theta_var

def estimate_local_Vs(hospitals,a_No):
    """Fitted-Q iteration with one ``LinearRegression`` per step.

    Works backwards from ``H-1``: regress reward + next-step value on
    ``phi(state, action)``, then set ``V_hat[h]`` to the max predicted Q over
    all actions.

    Mutates ``hospitals`` by adding "V_hats" (dict step -> values, including
    terminal step ``H`` of zeros) and "Qregs" (dict step -> fitted regressor),
    and returns the same object.
    """
    H = len(hospitals['actions'])
    n = len(hospitals['actions'][0])
    V_hat, Qreg = {H: np.zeros(n)}, {}
    for h in range(H-1, -1, -1):
        s = hospitals['states'][h]['state']
        X = phi(s=s, a=hospitals['actions'][h], a_No=a_No)
        y = hospitals['rewards'][h] + V_hat[h+1]
        reg = linear_model.LinearRegression()
        reg.fit(X, y)
        Qreg[h] = reg
        # Score every action, then take the max once (previously this max
        # was recomputed inside the per-action loop).
        Q_sa = np.vstack([reg.predict(phi(s=s, a=np.tile(a, n), a_No=a_No)) for a in range(a_No)])
        V_hat[h] = Q_sa.max(axis=0)
    hospitals['V_hats'] = V_hat
    hospitals['Qregs'] = Qreg
    return hospitals

def train_Qlearn0(hospitals,H,a_No):
    """Fit one pooled ``LinearRegression`` of ``V_hats[h]`` on ``phi(state, action)`` over all steps.

    Returns:
        The fitted regressor.
    """
    targets = np.hstack([hospitals['V_hats'][step] for step in range(H)])
    features = np.vstack([
        phi(s=hospitals['states'][step]['state'], a=hospitals['actions'][step], a_No=a_No)
        for step in range(H)
    ])
    model = linear_model.LinearRegression()
    model.fit(features, targets)
    return model

def simulations(hospitals,dim,a_No,H,episodes_No,test_No,seed,c):
    """Train the linear and GLM pessimistic policies with bonus scale ``c`` and evaluate both.

    Returns:
        ``np.array([V_linear, V_glm])``: values on ``test_No`` fresh trajectories
        generated with ``seed``.
    """
    d0 = a_No * dim
    lin_beta_h, lin_beta, lin_Lam = Linear(hospitals, a_No, d0, c)
    glm_beta_h, glm_beta, glm_beta0, glm_Lam, theta_hat, theta_var = GLM(hospitals, a_No, d0, episodes_No, c)

    V_linear = MDP(dim, a_No, H, test_No, prob=None, seed=seed, policy="linear",
                   beta_hat_h=lin_beta_h, Lamda=lin_Lam, beta=lin_beta)[0]

    V_glm = MDP2(dim, a_No, H, test_No, policy="glm", seed=seed, beta_hat_h=glm_beta_h,
                 theta_hat=theta_hat, theta_var=theta_var, Lamda=glm_Lam,
                 beta=glm_beta, beta0=glm_beta0)[0]

    return np.array([np.mean(V_linear), np.mean(V_glm)])

def simulations_Q(hospital,Qlearn0,dim,a_No,H,episodes_No,seed):
    """Evaluate the two Q-learning baselines on fresh trajectories.

    Returns:
        ``np.array([V_single, V_local])``. The first uses the pooled regressor
        ``Qlearn0``; the second uses the per-step regressors in ``hospital['Qregs']``.
    """
    def _greedy(h, states, predict, n_actions):
        # Score every action for every state at step h and pick the best one.
        s_h = states[h]['state']
        n_k = len(s_h)
        scores = [
            list(predict(phi(s=s_h, a=np.tile(act, n_k), a_No=n_actions)))
            for act in range(n_actions)
        ]
        return np.argmax(np.array(scores), 0)

    def single_Qlearn(h, states, Qlearn0=Qlearn0, a_No=a_No):
        # One pooled regressor shared by all steps.
        return _greedy(h, states, Qlearn0.predict, a_No)

    def local_Qlearn(h, states, hospital=hospital, a_No=a_No):
        # A separate regressor per step, fitted by estimate_local_Vs.
        return _greedy(h, states, hospital['Qregs'][h].predict, a_No)

    V_sq = [MDP(dim, a_No, H, episodes_No, prob=None, seed=seed, policy=single_Qlearn)[0]]
    V_l = [MDP(dim, a_No, H, episodes_No, prob=None, seed=seed, policy=local_Qlearn)[0]]

    return np.array([np.mean(V_sq), np.mean(V_l)])
        
def worker(params):
    """Pool entry point: unpack one parameter tuple and run ``simulations`` on it."""
    positional = tuple(params)
    return simulations(*positional)

def simu(dim, a_No, H, episodes_No, seed, C, method, test_No):
    """Run one full experiment: generate data, fit all methods, evaluate.

    Args:
        dim, a_No, H: state dimension, number of actions, horizon.
        episodes_No: number of logged training trajectories.
        seed: seed for data generation and evaluation.
        C: iterable of bonus scales; one (linear, glm) row per value.
        method: behaviour policy for data generation: "random", "optimal",
            or None (uniform). The string "None" (offered by the CLI) is
            treated as None.
        test_No: number of evaluation trajectories.

    Returns:
        Array of shape (len(C) + 1, 2). Rows 0..len(C)-1 are [linear, glm]
        for each c. The last row is the Q-learning baselines
        [single pooled Q, per-step Q].

    Raises:
        ValueError: for an unknown ``method``.
    """
    # argparse lists the literal string 'None' as a --method choice.
    if method == "None":
        method = None
    behaviour_policies = {"random": "random", "optimal": "optimal", None: None}
    if method not in behaviour_policies:
        raise ValueError(f"unknown data-generation method: {method!r}")

    print("---Generating Data---")
    prob = get_prob(H, a_No)
    _, states, actions, rewards, theta = MDP(dim, a_No, H, episodes_No, prob, seed=seed,
                                             policy=behaviour_policies[method])
    hospitals = {'states': states, 'actions': actions, 'rewards': rewards, "theta_true": theta}

    print("---Estimating Local Vs---")
    hospital = estimate_local_Vs(hospitals, a_No)
    Qlearn0 = train_Qlearn0(hospital, H, a_No=a_No)

    print("---Parallel Computing---")
    params_list = [(hospitals, dim, a_No, H, episodes_No, test_No, seed, c) for c in C]
    with Pool() as p:
        results = p.map(worker, params_list)

    print("---Merging Results---")
    # Each worker returns [linear, glm].
    result = np.array(results, dtype=float).reshape(len(params_list), 2)

    print("---Storing Qlearn0 Results---")
    # Append the Q-learning baselines [Single Q, Local Linear] as the last row.
    return np.vstack((result, simulations_Q(hospital, Qlearn0, dim, a_No, H, test_No, seed)))

import argparse
def parse_arguments():
    """Parse the command-line options for the simulation driver and return the namespace."""
    parser = argparse.ArgumentParser(description="Run the simu function with given parameters.")
    option_specs = [
        ('--dim', dict(type=int, default=10, help='Dimension')),
        ('--a_No', dict(type=int, default=2, help='a_No parameter')),
        ('--H', dict(type=int, default=10, help='H parameter')),
        ('--seed', dict(type=int, default=121, help='Seed for random generation')),
        ('--C', dict(type=float, nargs='+', default=[0, 0.005, 0.001, 0.0005, 0.0001], help='C parameter list')),
        ('--method', dict(type=str, default=None, choices=['random', 'optimal', 'None'], help='Method to be used for simulation')),
        ('--test_No', dict(type=int, default=500, help='Test set size')),
        ('--episodes_No', dict(type=int, default=200, help='Episode number for training')),
        ('--save_path', dict(type=str, default="data/", help='Save path')),
        ('--simulations', dict(type=int, default=2, help='Number of simulations')),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_arguments()

    dim = args.dim  # Dimension
    a_No = args.a_No  # Action Number
    H = args.H  # episode length for a single patient
    seed = args.seed  # base random seed
    C = args.C  # C parameter list
    method = args.method  # data generation method
    test_No = args.test_No  # test size
    save_path = args.save_path  # save path
    episodes_No = args.episodes_No  # train size
    simulations_times = args.simulations  # number of simulations
    print("Start Simulation:")
    print(args)

    # Each repetition gets its own seed: MDP reseeds NumPy, so reusing one seed
    # would replay identical runs and make the reported std identically zero.
    # Each simu() result has rows [linear, glm] per C, plus a final Q-learning row.
    results = np.array([
        simu(dim, a_No, H, episodes_No, seed + i, C, method, test_No)
        for i in range(simulations_times)
    ])

    mean = np.mean(results, axis=0)
    std = np.std(results, axis=0)

    # Store The Results (file names keep the base seed).
    os.makedirs(save_path, exist_ok=True)
    prefix = f"{dim}_{a_No}_{H}_{episodes_No}_{test_No}_{seed}_{simulations_times}"
    # NOTE: the last "Qlearning" row holds [single pooled Q, per-step Q] under
    # the same column labels.
    row_labels = [f"C={c}" for c in C] + ["Qlearning"]
    for stat_name, values in (("mean", mean), ("std", std)):
        frame = pd.DataFrame(values, columns=["linear", "glm"], index=row_labels)
        frame.to_csv(os.path.join(save_path, f"{prefix}_{stat_name}.csv"))


